import os.path
import os
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import matplotx

# Experiment grid. The GP loader below turns these values into file names;
# the plot labels use the sample sizes as "n=...", the length scales as
# "l=..." and the kernel sigmas as 1/sigma_0.
mysigmalist = [1, 3]
myNlist = [50, 100, 200, 500]
mylengthlist = [0.5, 1, 2]
mysigmakernel = [1, 2, 4]

# Epsilon values used on the x-axis of every figure (the plots drop the last
# two entries).
epsilon_all = np.array([
    1e-3, 2e-3, 3.3e-3, 5e-3, 6.7e-3, 8e-3,
    1e-2, 2e-2, 3.3e-2, 5e-2, 6.7e-2, 8e-2,
    1e-1, 0.125, 0.15, 0.175, 2e-1, 3.3e-1, 6.7e-1, 8e-1, 1,
])
# XBCF simulation outputs, indexed as [setting][sample size]:
#   first index  -> the 1/2/3 tag in the file name (index + 1),
#   second index -> sample sizes 50, 100, 200, 500.
# Each container gets its own row lists so that rows are never aliased.
XBCF_continuous_result = [[None] * 4 for _ in range(3)]
XBCF_continuous_result_mean = [[None] * 4 for _ in range(3)]
XBCF_continuous_result_quantile = [[None] * 4 for _ in range(3)]
for i in range(3):
    for j, xbcf_n in enumerate((50, 100, 200, 500)):
        xbcf_path = 'Data/Continuous/XBCF_{}_{}_.npy'.format(xbcf_n, i + 1)
        # Missing files are skipped and their entries stay None, the same way
        # the GP results are loaded; the plotting functions already guard on
        # "is not None" before drawing a curve.
        if not os.path.exists(xbcf_path):
            continue
        XBCF_continuous_result[i][j] = np.load(xbcf_path)
        # Summaries are taken along axis 0 of the loaded array.
        XBCF_continuous_result_mean[i][j] = XBCF_continuous_result[i][j].mean(axis=0)
        XBCF_continuous_result_quantile[i][j] = np.quantile(
            XBCF_continuous_result[i][j], q=[0.05, 0.25, 0.5, 0.75, 0.95], axis=0)


# GP simulation outputs, addressed as
# [kernel-sigma index][length-scale index][sigma index][sample-size index].
# Entries whose data file is absent remain None.
GP_continuous_result = [[[[None] * 4 for _ in range(2)] for _ in range(3)] for _ in range(3)]
GP_continuous_result_mean = [[[[None] * 4 for _ in range(2)] for _ in range(3)] for _ in range(3)]
GP_continuous_result_quantile = [[[[None] * 4 for _ in range(2)] for _ in range(3)] for _ in range(3)]

for g, sigmakernel in enumerate(mysigmakernel):
    for h, length_scale in enumerate(mylengthlist):
        for i, sigma in enumerate(mysigmalist):
            for j, N in enumerate(myNlist):
                # File name pattern: GP_<N>_<sigma>_<length>_<kernel sigma>_.npy
                mydataname = "_".join(
                    str(piece) for piece in ("GP", N, sigma, length_scale, sigmakernel, '.npy'))
                gp_path = 'Data/Continuous/' + mydataname
                if not os.path.exists(gp_path):
                    continue
                gp_runs = np.load(gp_path)
                GP_continuous_result[g][h][i][j] = gp_runs
                # Mean and quantiles are both taken along axis 0.
                GP_continuous_result_mean[g][h][i][j] = gp_runs.mean(axis=0)
                GP_continuous_result_quantile[g][h][i][j] = np.quantile(
                    gp_runs, q=[0.05, 0.25, 0.5, 0.75, 0.95], axis=0)

def align_yaxis(ax1, v1, ax2, v2):
    """Shift the y-limits of ax2 so that its value v2 lines up with v1 on ax1."""
    # Vertical display coordinates of the two reference values.
    display_v1 = ax1.transData.transform((0, v1))[1]
    display_v2 = ax2.transData.transform((0, v2))[1]
    to_data = ax2.transData.inverted()
    # Convert the display-space gap into an offset in ax2 data coordinates.
    origin = to_data.transform((0, 0))
    displaced = to_data.transform((0, display_v1 - display_v2))
    dy = (origin - displaced)[1]
    lower, upper = ax2.get_ylim()
    ax2.set_ylim(lower + dy, upper + dy)

def plot_continuous_xbcf_num1():
    """Plot the mean XBCF ACRisk curve against epsilon, one line per sample size.

    Draws row 0 of every loaded ``XBCF_continuous_result_mean[1][j]`` entry
    (with the last two epsilon grid points dropped) and saves the figure as
    ``Figure2/XBCF_Continuous_num1.pdf``, creating the directory if needed.
    """
    fig1, ax1 = plt.subplots(figsize=(12, 8))
    ax1.set_ylabel('Average ACRisk of the policy', fontsize=25)
    ax1.yaxis.set_tick_params(labelsize=20)
    ax1.xaxis.set_tick_params(labelsize=20)
    ax1.set_xlim(-0.0, 0.4)
    ax1.set_ylim(0, 0.15)
    for i, j, temp in [(1, 0, 0), (1, 1, 1), (1, 2, 2), (1, 3, 3)]:
        if XBCF_continuous_result_mean[i][j] is not None:
            NName = ['50', '100', '200', '500'][j]
            mylinetype = ['dotted', 'solid', 'dashed', 'dashdot'][temp]
            mycolor = ['black', 'blue', 'green', 'red'][temp]
            ax1.plot(epsilon_all[:-2], XBCF_continuous_result_mean[i][j][0, :-2], c=mycolor,
                     label='n=' + NName, linestyle=mylinetype, linewidth=3)
    matplotx.line_labels(fontsize=20)  # line labels to the right
    # Raw string: "\e" is not a valid Python escape sequence.
    ax1.set_xlabel(r'$\epsilon$ in the safety constraint', fontsize=28)
    # NOTE(review): row i=1 is labelled 'Medium' in plot_continuous_xbcf_signal1,
    # while this title says "low" -- confirm which wording is intended.
    ax1.set_title("ACRisk with different sample size (low signal-to-noise ratio)", fontsize=28)
    fig1.tight_layout()
    fig1.show()
    os.makedirs('Figure2', exist_ok=True)  # savefig does not create missing directories
    fig1.savefig('Figure2/' + 'XBCF_Continuous_num1' + '.pdf')

def plot_continuous_xbcf_num2():
    """Plot the mean XBCF policy value against epsilon, one line per sample size.

    Draws row 1 of every loaded ``XBCF_continuous_result_mean[1][j]`` entry
    (with the last two epsilon grid points dropped), adds a horizontal
    baseline at 0, and saves the figure as
    ``Figure2/XBCF_Continuous_num2.pdf``, creating the directory if needed.
    """
    fig1, ax1 = plt.subplots(figsize=(12, 8))
    ax1.set_ylabel('Average Value of the policy', fontsize=25)
    ax1.yaxis.set_tick_params(labelsize=20)
    ax1.xaxis.set_tick_params(labelsize=20)
    ax1.set_xlim(-0.0, 0.4)
    ax1.set_ylim(-0.016, 0.041)
    for i, j, temp in [(1, 0, 0), (1, 1, 1), (1, 2, 2), (1, 3, 3)]:
        if XBCF_continuous_result_mean[i][j] is not None:
            NName = ['50', '100', '200', '500'][j]
            mylinetype = ['dotted', 'solid', 'dashed', 'dashdot'][temp]
            mycolor = ['black', 'blue', 'green', 'red'][temp]
            ax1.plot(epsilon_all[:-2], XBCF_continuous_result_mean[i][j][1, :-2], c=mycolor,
                     label='n=' + NName, linestyle=mylinetype, linewidth=3)
    matplotx.line_labels(fontsize=20)  # line labels to the right
    # Raw string: "\e" is not a valid Python escape sequence.
    ax1.set_xlabel(r'$\epsilon$ in the safety constraint', fontsize=28)
    # NOTE(review): row i=1 is labelled 'Medium' in plot_continuous_xbcf_signal1,
    # while this title says "low" -- confirm which wording is intended.
    ax1.set_title("Value with different sample size (low signal-to-noise ratio)", fontsize=28)
    # Drawn after line_labels(), so the baseline presumably gets no right-hand label.
    ax1.axhline(y=0, linestyle='--', color='black', linewidth=3, label='Baseline Policy')
    fig1.tight_layout()
    fig1.show()
    os.makedirs('Figure2', exist_ok=True)  # savefig does not create missing directories
    fig1.savefig('Figure2/' + 'XBCF_Continuous_num2' + '.pdf')

def plot_continuous_xbcf_signal1():
    """Plot the mean XBCF ACRisk curve against epsilon, one line per signal level.

    Draws row 0 of every loaded ``XBCF_continuous_result_mean[i][1]`` entry
    (the n=100 column, last two epsilon grid points dropped), labelled
    High / Medium / Low by ``i``, and saves the figure as
    ``Figure2/XBCF_Continuous_signal1.pdf``, creating the directory if needed.
    """
    fig1, ax1 = plt.subplots(figsize=(12, 8))
    ax1.set_ylabel('Average ACRisk of the policy', fontsize=25)
    ax1.yaxis.set_tick_params(labelsize=20)
    ax1.xaxis.set_tick_params(labelsize=20)
    ax1.set_xlim(-0.0, 0.4)
    ax1.set_ylim(0, 0.15)
    for i, j, temp in [(0, 1, 0), (1, 1, 1), (2, 1, 2)]:
        if XBCF_continuous_result_mean[i][j] is not None:
            SigmaName = ['High', 'Medium', 'Low'][i]
            mylinetype = ['dotted', 'solid', 'dashed', 'dashdot'][temp]
            mycolor = ['black', 'blue', 'green', 'red'][temp]
            ax1.plot(epsilon_all[:-2], XBCF_continuous_result_mean[i][j][0, :-2], c=mycolor,
                     label=SigmaName, linestyle=mylinetype, linewidth=3)
    matplotx.line_labels(fontsize=20)  # line labels to the right
    # Raw string: "\e" is not a valid Python escape sequence.
    ax1.set_xlabel(r'$\epsilon$ in the safety constraint', fontsize=28)
    ax1.set_title("ACRisk with different signal strength (n=100)", fontsize=28)
    fig1.tight_layout()
    fig1.show()
    os.makedirs('Figure2', exist_ok=True)  # savefig does not create missing directories
    fig1.savefig('Figure2/' + 'XBCF_Continuous_signal1' + '.pdf')

def plot_continuous_xbcf_signal2():
    """Plot the mean XBCF policy value against epsilon, one line per signal level.

    Draws row 1 of every loaded ``XBCF_continuous_result_mean[i][1]`` entry
    (the n=100 column, last two epsilon grid points dropped), labelled
    High / Medium / Low by ``i``, adds a horizontal baseline at 0, and saves
    the figure as ``Figure2/XBCF_Continuous_signal2.pdf``, creating the
    directory if needed.
    """
    fig1, ax1 = plt.subplots(figsize=(12, 8))
    ax1.set_ylabel('Average Value of the policy', fontsize=25)
    ax1.yaxis.set_tick_params(labelsize=20)
    ax1.xaxis.set_tick_params(labelsize=20)
    ax1.set_xlim(-0.0, 0.4)
    ax1.set_ylim(-0.016, 0.041)
    for i, j, temp in [(0, 1, 0), (1, 1, 1), (2, 1, 2)]:
        if XBCF_continuous_result_mean[i][j] is not None:
            SigmaName = ['High', 'Medium', 'Low'][i]
            mylinetype = ['dotted', 'solid', 'dashed', 'dashdot'][temp]
            mycolor = ['black', 'blue', 'green', 'red'][temp]
            ax1.plot(epsilon_all[:-2], XBCF_continuous_result_mean[i][j][1, :-2], c=mycolor,
                     label=SigmaName, linestyle=mylinetype, linewidth=3)
    matplotx.line_labels(fontsize=20)  # line labels to the right
    # Raw string: "\e" is not a valid Python escape sequence.
    ax1.set_xlabel(r'$\epsilon$ in the safety constraint', fontsize=28)
    # Drawn after line_labels(), so the baseline presumably gets no right-hand label.
    ax1.axhline(y=0, linestyle='--', color='black', linewidth=3, label='Baseline Policy')
    ax1.set_title("Value with different signal strength (n=100)", fontsize=28)
    fig1.tight_layout()
    fig1.show()
    os.makedirs('Figure2', exist_ok=True)  # savefig does not create missing directories
    fig1.savefig('Figure2/' + 'XBCF_Continuous_signal2' + '.pdf')

# Render and save the four XBCF figures, in this order.
for _draw_xbcf_figure in (
    plot_continuous_xbcf_num1,
    plot_continuous_xbcf_num2,
    plot_continuous_xbcf_signal1,
    plot_continuous_xbcf_signal2,
):
    _draw_xbcf_figure()


def plot_continuous_gp_l1():
    """Plot the mean GP ACRisk curve against epsilon, one line per length scale.

    Draws row 0 of every loaded ``GP_continuous_result_mean[1][h][1][2]``
    entry (last two epsilon grid points dropped), labelled l=0.5 / l=1 / l=2
    by ``h``, and saves the figure as ``Figure2/GP_Continuous_l1.pdf``,
    creating the directory if needed.
    """
    fig1, ax1 = plt.subplots(figsize=(12, 8))
    ax1.set_ylabel('Average ACRisk of the policy', fontsize=25)
    ax1.yaxis.set_tick_params(labelsize=20)
    ax1.xaxis.set_tick_params(labelsize=20)
    ax1.set_xlim(-0.0, 0.4)
    ax1.set_ylim(0, 0.209)
    for g, h, i, j, temp in [(1, 0, 1, 2, 0), (1, 1, 1, 2, 1), (1, 2, 1, 2, 2)]:
        if GP_continuous_result_mean[g][h][i][j] is not None:
            LengthName = ['l=0.5', 'l=1', 'l=2'][h]
            mycolor = ['black', 'blue', 'green', 'red'][temp]
            mylinetype = ['solid', 'dotted', 'dashed', 'dashdot'][temp]
            ax1.plot(epsilon_all[:-2], GP_continuous_result_mean[g][h][i][j][0, :-2], c=mycolor,
                     linestyle=mylinetype, linewidth=3, label=LengthName)
    matplotx.line_labels(fontsize=20)  # line labels to the right
    # Raw strings: "\e" and "\s" are not valid Python escape sequences.
    ax1.set_xlabel(r'$\epsilon$ in the safety constraint', fontsize=28)
    ax1.set_title(r"ACRisk with different prior smoothness l ($\sigma_0=2)$", fontsize=28)
    fig1.tight_layout()
    fig1.show()
    os.makedirs('Figure2', exist_ok=True)  # savefig does not create missing directories
    fig1.savefig('Figure2/' + 'GP_Continuous_l1' + '.pdf')

def plot_continuous_gp_l2():
    """Plot the mean GP policy value against epsilon, one line per length scale.

    Draws row 1 of every loaded ``GP_continuous_result_mean[1][h][1][2]``
    entry (last two epsilon grid points dropped), labelled l=0.5 / l=1 / l=2
    by ``h``, and saves the figure as ``Figure2/GP_Continuous_l2.pdf``,
    creating the directory if needed.
    """
    fig1, ax1 = plt.subplots(figsize=(12, 8))
    ax1.set_ylabel('Average Value of the policy', fontsize=25)
    ax1.yaxis.set_tick_params(labelsize=20)
    ax1.xaxis.set_tick_params(labelsize=20)
    ax1.set_xlim(-0.0, 0.4)
    ax1.set_ylim(-0.002, 0.078)
    for g, h, i, j, temp in [(1, 0, 1, 2, 0), (1, 1, 1, 2, 1), (1, 2, 1, 2, 2)]:
        if GP_continuous_result_mean[g][h][i][j] is not None:
            LengthName = ['l=0.5', 'l=1', 'l=2'][h]
            mycolor = ['black', 'blue', 'green', 'red'][temp]
            mylinetype = ['solid', 'dotted', 'dashed', 'dashdot'][temp]
            ax1.plot(epsilon_all[:-2], GP_continuous_result_mean[g][h][i][j][1, :-2], c=mycolor,
                     linestyle=mylinetype, linewidth=3, label=LengthName)
    matplotx.line_labels(fontsize=20)  # line labels to the right
    # Raw strings: "\e" and "\s" are not valid Python escape sequences.
    ax1.set_xlabel(r'$\epsilon$ in the safety constraint', fontsize=28)
    ax1.set_title(r"Value with different prior smoothness l ($\sigma_0=2)$", fontsize=28)
    fig1.tight_layout()
    fig1.show()
    os.makedirs('Figure2', exist_ok=True)  # savefig does not create missing directories
    fig1.savefig('Figure2/' + 'GP_Continuous_l2' + '.pdf')

def plot_continuous_gp_sigma01():
    """Plot the mean GP ACRisk curve against epsilon, one line per prior strength.

    Draws row 0 of every loaded ``GP_continuous_result_mean[g][2][1][2]``
    entry (last two epsilon grid points dropped), labelled by 1/sigma_0
    according to ``g``, and saves the figure as
    ``Figure2/GP_Continuous_sigma01.pdf``, creating the directory if needed.
    """
    fig1, ax1 = plt.subplots(figsize=(12, 8))
    ax1.set_ylabel('Average ACRisk of the policy', fontsize=25)
    ax1.yaxis.set_tick_params(labelsize=20)
    ax1.xaxis.set_tick_params(labelsize=20)
    ax1.set_xlim(-0.0, 0.4)
    ax1.set_ylim(0, 0.209)
    for g, h, i, j, temp in [(0, 2, 1, 2, 0), (1, 2, 1, 2, 1), (2, 2, 1, 2, 2)]:
        if GP_continuous_result_mean[g][h][i][j] is not None:
            # Raw strings keep the LaTeX backslashes literal ("\s" is not a
            # valid Python escape sequence); the rendered text is unchanged.
            PriorName = [r'$\frac{1}{\sigma_0}=1$', r'$\frac{1}{\sigma_0}=0.5$',
                         r'$\frac{1}{\sigma_0}=0.25$'][g]
            mycolor = ['black', 'blue', 'green', 'red'][temp]
            mylinetype = ['solid', 'dotted', 'dashed', 'dashdot'][temp]
            ax1.plot(epsilon_all[:-2], GP_continuous_result_mean[g][h][i][j][0, :-2], c=mycolor,
                     linestyle=mylinetype, linewidth=3, label=PriorName)
    matplotx.line_labels(fontsize=20)  # line labels to the right
    ax1.set_xlabel(r'$\epsilon$ in the safety constraint', fontsize=28)
    ax1.set_title(r"ACRisk with different prior strength $\frac{1}{\sigma_0}$ (l=2)", fontsize=28)
    fig1.tight_layout()
    fig1.show()
    os.makedirs('Figure2', exist_ok=True)  # savefig does not create missing directories
    fig1.savefig('Figure2/' + 'GP_Continuous_sigma01' + '.pdf')

def plot_continuous_gp_sigma02():
    """Plot the mean GP policy value against epsilon, one line per prior strength.

    Draws row 1 of every loaded ``GP_continuous_result_mean[g][2][1][2]``
    entry (last two epsilon grid points dropped), labelled by 1/sigma_0
    according to ``g``, adds a horizontal baseline at 0, and saves the figure
    as ``Figure2/GP_Continuous_sigma02.pdf``, creating the directory if needed.
    """
    fig1, ax1 = plt.subplots(figsize=(12, 8))
    ax1.set_ylabel('Average Value of the policy', fontsize=25)
    ax1.yaxis.set_tick_params(labelsize=20)
    ax1.xaxis.set_tick_params(labelsize=20)
    ax1.set_xlim(-0.0, 0.4)
    ax1.set_ylim(-0.002, 0.078)
    for g, h, i, j, temp in [(0, 2, 1, 2, 0), (1, 2, 1, 2, 1), (2, 2, 1, 2, 2)]:
        if GP_continuous_result_mean[g][h][i][j] is not None:
            # Raw strings keep the LaTeX backslashes literal ("\s" is not a
            # valid Python escape sequence); the rendered text is unchanged.
            PriorName = [r'$\frac{1}{\sigma_0}=1$', r'$\frac{1}{\sigma_0}=0.5$',
                         r'$\frac{1}{\sigma_0}=0.25$'][g]
            mycolor = ['black', 'blue', 'green', 'red'][temp]
            mylinetype = ['solid', 'dotted', 'dashed', 'dashdot'][temp]
            ax1.plot(epsilon_all[:-2], GP_continuous_result_mean[g][h][i][j][1, :-2], c=mycolor,
                     linestyle=mylinetype, linewidth=3, label=PriorName)
    matplotx.line_labels(fontsize=20)  # line labels to the right
    # Drawn after line_labels(), so the baseline presumably gets no right-hand label.
    ax1.axhline(y=0, linestyle='--', color='black', linewidth=3, label='Baseline Policy')
    ax1.set_xlabel(r'$\epsilon$ in the safety constraint', fontsize=28)
    ax1.set_title(r"Value with different prior strength $\frac{1}{\sigma_0}$ ($l=2)$", fontsize=28)
    fig1.tight_layout()
    fig1.show()
    os.makedirs('Figure2', exist_ok=True)  # savefig does not create missing directories
    fig1.savefig('Figure2/' + 'GP_Continuous_sigma02' + '.pdf')

# Render and save the four GP figures, in this order.
for _draw_gp_figure in (
    plot_continuous_gp_l1,
    plot_continuous_gp_l2,
    plot_continuous_gp_sigma01,
    plot_continuous_gp_sigma02,
):
    _draw_gp_figure()
